#
#  Copyright 2019 The FATE Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
import os
import typing

import click

if typing.TYPE_CHECKING:
    from fate.components.core.spec.task import TaskConfigSpec


@click.command()
@click.option("--process-tag", required=False, help="unique id to identify this execution process")
@click.option("--config", required=False, type=click.File(), help="config path")
@click.option("--config-entrypoint", required=False, help="entrypoint to get config")
@click.option("--properties", "-p", multiple=True, help="properties config")
@click.option("--env-prefix", "-e", type=str, default="runtime.component_desc.", help="prefix for env config")
@click.option("--env-name", required=False, type=str, help="env name for config")
@click.option(
    "--execution-final-meta-path",
    type=click.Path(exists=False, dir_okay=False, writable=True, resolve_path=True),
    default=os.path.join(os.getcwd(), "execution_final_meta.json"),
    show_default=True,
    help="path for execution meta generated by component when execution finished",
)
@click.option("--debug", is_flag=True, help="enable debug mode")
def execute(
    process_tag, config, config_entrypoint, properties, env_prefix, env_name, execution_final_meta_path, debug
):
    """
    execute component

    \f
    Everything after the form feed above is hidden from ``--help`` by click.

    Build a ``TaskConfigSpec`` from the supplied config sources, install the
    configured logger, reject components listed in ``cfg.components.blacklist``
    and hand off to ``execute_component_from_config``.

    Config loaders are invoked in this order: env (``--env-name``), entrypoint
    (``--config-entrypoint``), file (``--config``), then properties. The
    properties are the ``-p`` values and the environment variables matching
    ``--env-prefix``. NOTE(review): whether later sources override earlier ones
    is decided by the ``load_config_from_*`` helpers; confirm there.

    ``process_tag`` is accepted but not used in this function.

    Raises:
        click.UsageError: if none of config, config-entrypoint, properties or
            env-name is provided.
        RuntimeError: if the configured component is blacklisted.
    """
    import logging

    from fate.components.core.spec.task import TaskConfigSpec
    from fate.components.entrypoint.utils import (
        load_config_from_entrypoint,
        load_config_from_env,
        load_config_from_file,
        load_config_from_properties,
        load_properties,
        load_properties_from_env,
    )

    if config is None and config_entrypoint is None and not properties and env_name is None:
        raise click.UsageError("at least one of config, config-entrypoint, properties, env-name should be provided")

    # parse properties: `-p` values first, then prefixed env vars; dict.update means
    # the env-var value wins when both define the same key
    properties_items = {}
    properties_items.update(load_properties(properties))
    properties_items.update(load_properties_from_env(env_prefix))

    # parse config: the loaders are called for their effect on `configs`
    # (their return values are ignored)
    configs = {}
    load_config_from_env(configs, env_name)
    load_config_from_entrypoint(configs, config_entrypoint)
    load_config_from_file(configs, config)
    load_config_from_properties(configs, properties_items)

    task_config = TaskConfigSpec.parse_obj(configs)

    # install logger
    task_config.conf.logger.install(debug=debug)
    logger = logging.getLogger(__name__)
    logger.debug("logger installed")
    # lazy %-formatting: the config is only rendered when debug logging is enabled
    logger.debug("task config: %s", task_config)

    meta_dir = os.path.dirname(execution_final_meta_path)
    # os.makedirs("") raises; an empty dirname means the current directory already serves
    if meta_dir:
        os.makedirs(meta_dir, exist_ok=True)

    from fate.arch.config import cfg

    if task_config.component in cfg.components.blacklist:
        raise RuntimeError(f"component `{task_config.component}` is in blacklist, do not use it")

    execute_component_from_config(task_config, execution_final_meta_path)


def execute_component_from_config(config: "TaskConfigSpec", output_path):
    """Run the component task described by ``config`` and record its final status.

    On success, writes ``{"status": {"code": 0}, "io_meta": <io meta>}`` as
    indented JSON to ``output_path``.

    On any exception, the function does three things:

    * writes ``{"status": {"code": -1, "exceptions": <formatted traceback>}}``
      to ``output_path``;
    * logs the error;
    * re-raises the same exception.

    Args:
        config: parsed task configuration (device, computing, federation,
            logger confs, component name, role and stage).
        output_path: path of the JSON file receiving the final execution status.

    Raises:
        Exception: whatever setup or component execution raised, re-raised as is.
            Failing to write the success meta surfaces as ``RuntimeError``.
    """
    import json
    import logging
    import traceback

    from fate.arch import CipherKit, Context
    from fate.arch.trace import profile_ends, profile_start
    from fate.components.core import (
        ComponentExecutionIO,
        Role,
        Stage,
        load_component,
        load_computing,
        load_device,
        load_federation,
        load_metric_handler,
        is_root_worker,
    )

    logger = logging.getLogger(__name__)
    logger.debug(f"logging final status to  `{output_path}`")
    try:
        device = load_device(config.conf.device)
        computing = load_computing(config.conf.computing, config.conf.logger.config)
        # only the root worker sets up federation; other workers get federation=None
        if is_root_worker():
            federation = load_federation(config.conf.federation, computing)
        else:
            federation = None
            logger.info("skip federation initialization for non-root worker")
        cipher = CipherKit(device=device)
        ctx = Context(
            device=device,
            computing=computing,
            federation=federation,
            cipher=cipher,
        )
        role = Role.from_str(config.role)
        stage = Stage.from_str(config.stage)
        logger.debug(f"component={config.component}, context={ctx}")
        logger.debug("running...")

        # get correct component_desc/subcomponent handle stage
        component = load_component(config.component, stage)

        # enable profiling
        # NOTE(review): profile_ends() below is only reached on success; confirm
        # that skipping it on failure is intended
        profile_start()

        # prepare
        execution_io = ComponentExecutionIO(ctx, component, role, stage, config)

        # register metric handler
        metrics_handler = load_metric_handler(execution_io.get_metric_writer())
        ctx.set_metric_handler(metrics_handler)

        # execute
        component.execute(ctx, role, **execution_io.get_kwargs())

        # finalize metric handler (not reached if execute raises)
        metrics_handler.finalize()
        # final execution io meta
        execution_io_meta = execution_io.dump_io_meta()
        try:
            with open(output_path, "w") as fw:
                json.dump(dict(status=dict(code=0), io_meta=execution_io_meta), fw, indent=4)
        except Exception as e:
            raise RuntimeError(f"failed to dump execution io meta to `{output_path}`: meta={execution_io_meta}") from e

        profile_ends()
        logger.debug("done without error, waiting signal to terminate")
        logger.debug("terminating, bye~")

    except Exception as e:
        logger.error(e, exc_info=True)
        with open(output_path, "w") as fw:
            json.dump(dict(status=dict(code=-1, exceptions=traceback.format_exc())), fw)
        # bare `raise` re-raises the active exception with its original traceback intact
        raise


if __name__ == "__main__":
    # Script entry point: invoke the click command, which parses the command-line arguments.
    execute()
